{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "---\n",
    "skip_exec: true\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp callback.comet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "\n",
    "import tempfile\n",
    "\n",
    "from fastai.basics import *\n",
    "from fastai.learner import Callback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comet.ml\n",
    "\n",
    "> Integration with [Comet.ml](https://www.comet.ml/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Registration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Create account: [comet.ml/signup](https://www.comet.ml/signup).\n",
    "2. Export API key to the environment variable (more help [here](https://www.comet.ml/docs/v2/guides/getting-started/quickstart/#get-an-api-key)). In your terminal run:\n",
    "\n",
    "```\n",
    "export COMET_API_KEY='YOUR_LONG_API_TOKEN'\n",
    "```\n",
    "\n",
    "or include it in your `./comet.config` file (**recommended**). More help is [here](https://www.comet.ml/docs/v2/guides/tracking-ml-training/jupyter-notebooks/#set-your-api-key-and-project-name)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. You need to install the `comet_ml` package. In your terminal run:\n",
    "\n",
    "```\n",
    "pip install comet_ml\n",
    "```\n",
    "\n",
    "or (alternative installation using conda). In your terminal run:\n",
    "\n",
    "```\n",
    "conda install -c anaconda -c conda-forge -c comet_ml comet_ml\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to use?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The key is to create the `CometCallback` callback before you create the `Learner()`, like this:\n",
    "\n",
    "```\n",
    "from fastai.callback.comet import CometCallback\n",
    "\n",
    "comet_ml_callback = CometCallback('PROJECT_NAME')  # specify project\n",
    "\n",
    "learn = Learner(dls, model,\n",
    "                cbs=comet_ml_callback\n",
    "                )\n",
    "\n",
    "learn.fit_one_cycle(1)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import comet_ml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class CometCallback(Callback):\n",
    "    \"Log losses, metrics, model weights and a model architecture summary to Comet\"\n",
    "    # Sorts after `Recorder`, whose `metric_names` and `log` are read in `after_epoch`.\n",
    "    order = Recorder.order + 1\n",
    "\n",
    "    def __init__(self, project_name, log_model_weights=True, keep_experiment_running=False):\n",
    "        \"\"\"Store the logging options; the experiment itself is created in `before_fit`.\n",
    "\n",
    "        project_name: passed as `project_name` to `comet_ml.Experiment`.\n",
    "        log_model_weights: log the `save_model` checkpoint file as an asset after every epoch.\n",
    "        keep_experiment_running: skip `experiment.end()` in `after_fit` and reuse the same\n",
    "            experiment on the next fit instead of creating a new one.\n",
    "        \"\"\"\n",
    "        self.log_model_weights = log_model_weights\n",
    "        self.keep_experiment_running = keep_experiment_running\n",
    "        self.project_name = project_name\n",
    "        self.experiment = None\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Start the experiment, then log the run parameters and a text summary of the model\"\n",
    "        # An experiment left open by a previous fit is reused only when that was requested.\n",
    "        if self.experiment is None or not self.keep_experiment_running:\n",
    "            self.experiment = None\n",
    "            try:\n",
    "                self.experiment = comet_ml.Experiment(project_name=self.project_name)\n",
    "            except ValueError:\n",
    "                print(\"No active experiment\")\n",
    "        # Without an experiment there is nothing to log to; the other hooks check for this too.\n",
    "        if self.experiment is None:\n",
    "            return\n",
    "\n",
    "        try:\n",
    "            self.experiment.log_parameter(\"n_epoch\", str(self.learn.n_epoch))\n",
    "            self.experiment.log_parameter(\"model_class\", str(type(self.learn.model)))\n",
    "        except Exception:\n",
    "            print(\"Did not log all properties.\")\n",
    "\n",
    "        try:\n",
    "            with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "                # Write through the handle that is already open (reopening the file by name\n",
    "                # is not portable) and flush so the content is on disk before it is logged.\n",
    "                f.write(repr(self.learn.model))\n",
    "                f.flush()\n",
    "                self.experiment.log_asset(f.name, \"model_summary.txt\")\n",
    "        except Exception:\n",
    "            print(\"Did not log model summary. Check if your model is PyTorch model.\")\n",
    "\n",
    "        if self.log_model_weights and not hasattr(self.learn, \"save_model\"):\n",
    "            # `after_epoch` only uploads weights when the learner has a `save_model` attribute.\n",
    "            print(\n",
    "                \"Unable to log model to Comet.\\n\",\n",
    "                \"The learner has no `save_model` callback, so no checkpoint will be uploaded.\",\n",
    "            )\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Log the losses, the iteration counter and the optimizer hyper-parameters while training\"\n",
    "        if self.experiment is None or not self.learn.training:\n",
    "            return\n",
    "        self.experiment.log_metric(\"batch__smooth_loss\", self.learn.smooth_loss)\n",
    "        self.experiment.log_metric(\"batch__loss\", self.learn.loss)\n",
    "        self.experiment.log_metric(\"batch__train_iter\", self.learn.train_iter)\n",
    "        # The metric name carries no group index, so every entry of `opt.hypers`\n",
    "        # is logged under the same `batch__opt.hypers.<key>` name.\n",
    "        for h in self.learn.opt.hypers:\n",
    "            for k, v in h.items():\n",
    "                self.experiment.log_metric(f\"batch__opt.hypers.{k}\", v)\n",
    "\n",
    "    def after_epoch(self):\n",
    "        \"Log the recorder's values for the epoch and, if enabled, the saved model weights\"\n",
    "        if self.experiment is None:\n",
    "            return\n",
    "        for n, v in zip(self.learn.recorder.metric_names, self.learn.recorder.log):\n",
    "            if n == \"time\":\n",
    "                # `log_text` takes the text first and an optional step second, so the\n",
    "                # name and the value are sent together as a single text entry.\n",
    "                self.experiment.log_text(f\"epoch__{n}: {v}\")\n",
    "            elif n != \"epoch\":\n",
    "                self.experiment.log_metric(f\"epoch__{n}\", v)\n",
    "\n",
    "        # log model weights\n",
    "        if self.log_model_weights and hasattr(self.learn, \"save_model\"):\n",
    "            save_model = self.learn.save_model\n",
    "            # With `every_epoch` set, the file name carries the epoch as a suffix.\n",
    "            if save_model.every_epoch:\n",
    "                fname = f\"{save_model.fname}_{save_model.epoch}\"\n",
    "            else:\n",
    "                fname = save_model.fname\n",
    "            _file = join_path_file(fname, self.learn.path / self.learn.model_dir, ext=\".pth\")\n",
    "            # NOTE(review): presumably the checkpoint must already be on disk at this point;\n",
    "            # confirm how this callback is ordered relative to the one that writes it.\n",
    "            self.experiment.log_asset(_file)\n",
    "\n",
    "    def after_fit(self):\n",
    "        \"End the experiment unless `keep_experiment_running` is set\"\n",
    "        if self.keep_experiment_running:\n",
    "            return\n",
    "        try:\n",
    "            self.experiment.end()\n",
    "        except Exception:\n",
    "            print(\"No Comet experiment to stop.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### CometCallback\n",
       "\n",
       ">      CometCallback (project_name, log_model_weights=True)\n",
       "\n",
       "Log losses, metrics, model weights, model architecture summary to neptune"
      ],
      "text/plain": [
       "<nbdev.showdoc.BasicMarkdownRenderer>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(CometCallback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
